This workspace contains a tiny subset (128MB) of the full dataset available (12GB). Feel free to use this workspace to build your project, or to explore a smaller subset with Spark before deploying your cluster on the cloud. Instructions for setting up your Spark cluster are included in the last lesson of the Extracurricular Spark Course content.
You can follow the steps below to guide your data analysis and model building portion of this project.
Introduction
In this project, we use the event log of the music streaming service Sparkify (a Spotify-like service) to
build a classification model that identifies users who are likely to cancel because they are dissatisfied with the service.
Software Requirements
Data Requirements
Steps
# import libraries
import datetime

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
#%matplotlib inline
import plotly.offline as offline
import plotly.graph_objs as go

import pyspark
from pyspark import SparkConf
from pyspark.sql import SparkSession, Window
from pyspark.sql.functions import (isnan, isnull, count, when, col, desc, udf,
                                   sort_array, asc, avg, datediff, weekofyear,
                                   to_date, from_unixtime)
from pyspark.sql.functions import sum as Fsum
from pyspark.sql.types import StringType, IntegerType, DateType
from pyspark.ml import Pipeline
from pyspark.ml.feature import MinMaxScaler, StandardScaler, VectorAssembler
from pyspark.ml.classification import LogisticRegression, RandomForestClassifier, GBTClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

offline.init_notebook_mode()
%%HTML
<style>
div#notebook-container { width: 100%; }
div#menubar-container { width: 65%; }
div#maintoolbar-container { width: 99%; }
</style>
# create a Spark session
spark = SparkSession \
.builder \
.appName("Sparkify Project") \
.getOrCreate()
spark.sparkContext.getConf().getAll()
spark
In this workspace, the mini-dataset file is mini_sparkify_event_data.json. Load and clean the dataset, checking for invalid or missing data - for example, records without userids or sessionids.
path = "mini_sparkify_event_data.json"
user_log = spark.read.json(path)
user_log.show(n=2)
Clean Dataset
print("row :",user_log.select(col('*')).count())
print("sessionId Nan :",user_log.filter(isnan(col('sessionId'))).count())
print("sessionId NULL :",user_log.filter(isnull(col('sessionId'))).count())
print("sessionId \"\" :",user_log.filter((col('sessionId'))=="").count())
print("userId Nan   :",user_log.filter(isnan(col('userId'))).count())
print("userId Null  :",user_log.filter(isnull(col('userId'))).count())
print("userId \"\"   :",user_log.filter((col('userId'))=="").count())
user_log=user_log.filter((col('userId'))!="")
print("row :",user_log.select(col('*')).count())
print("sessionId Nan :",user_log.filter(isnan(col('sessionId'))).count())
print("sessionId NULL :",user_log.filter(isnull(col('sessionId'))).count())
print("sessionId \"\" :",user_log.filter((col('sessionId'))=="").count())
print("userId Nan   :",user_log.filter(isnan(col('userId'))).count())
print("userId Null  :",user_log.filter(isnull(col('userId'))).count())
print("userId \"\"   :",user_log.filter((col('userId'))=="").count())
When you're working with the full dataset, perform EDA by loading a small subset of the data and doing basic manipulations within Spark. In this workspace, you are already provided a small subset of data you can explore.
Once you've done some preliminary analysis, create a column Churn to use as the label for your model. I suggest using the Cancellation Confirmation events to define your churn, which happen for both paid and free users. As a bonus task, you can also look into the Downgrade events.
Once you've defined churn, perform some exploratory data analysis to observe the behavior of users who stayed vs users who churned. You can start by exploring aggregates on these two groups, observing how often they performed a specific action per unit of time or per number of songs played.
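The kind of per-group aggregate described above can be sketched with pandas on a toy event log (the column names mirror the Sparkify schema, but the values are made up for illustration):

```python
import pandas as pd

# Toy event log: one row per user action (values are hypothetical).
events = pd.DataFrame({
    "userId": ["1", "1", "2", "2", "2", "3"],
    "page":   ["NextSong", "Thumbs Down", "NextSong", "NextSong", "Thumbs Up", "NextSong"],
    "churn":  [1, 1, 0, 0, 0, 0],
})

# Average number of songs played per user, split by churn status.
songs = (events[events["page"] == "NextSong"]
         .groupby(["churn", "userId"]).size()       # songs per user
         .groupby(level="churn").mean())            # mean per churn group
print(songs)
```

The same shape of computation (filter, per-user count, per-group mean) is what the Spark aggregations below perform at scale.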
user_log.printSchema()
explore= user_log.select('status').dropDuplicates().collect()
set(explore)
explore= user_log.select('level').dropDuplicates().collect()
set(explore)
explore= user_log.select('auth').dropDuplicates().collect()
set(explore)
explore= user_log.select('method').dropDuplicates().collect()
set(explore)
explore= user_log.select('userAgent').dropDuplicates().collect()
set(explore)
explore= user_log.select('location').dropDuplicates().collect()
set(explore)
explore= user_log.select('page').dropDuplicates().collect()
set(explore)
user_log.filter("page = 'Submit Downgrade'").show(3)
user_log.filter("page = 'Cancellation Confirmation'").show(3)
user_log.select(["userId", "firstname","itemInSession" ,"page", "level", "song"]).where(user_log.userId == "143").sort("ts").show(3)
def plot_hist(df, param, title, x_axis, y_axis):
    """
    Plot a histogram with plotly.
    Input
        df    : dataframe that includes the column 'Churn' and the parameter
        param : parameter (column) name
        title : graph title
        x_axis: x-axis label
        y_axis: y-axis label
    """
    Churn1 = go.Histogram(x=df.filter(col('Churn') == 1).toPandas()[param], name="Churn=1", opacity=0.5)
    Churn0 = go.Histogram(x=df.filter(col('Churn') != 1).toPandas()[param], name="Churn=0", opacity=0.5)
    layout = go.Layout(
        title=title,
        xaxis=dict(title=x_axis),
        yaxis=dict(title=y_axis),
        bargap=0.2,
        bargroupgap=0.1,
        width=800,
        height=300
    )
    fig = dict(data=[Churn1, Churn0], layout=layout)
    offline.iplot(fig, filename=param)
Find churned users from the "Cancellation Confirmation" event, mark that event with the flag column "upgrade", and cumulatively sum it per user to create the "Churn" label.
windowval = Window\
.partitionBy("userId")\
.orderBy(desc("ts"))\
.rangeBetween(Window.unboundedPreceding, 0)
flag_Cancellation_event = udf(lambda x: 1 if x == "Cancellation Confirmation" else 0, IntegerType())
#flag_downgrade_event = udf(lambda x: 1 if x == "Submit Downgrade" else 2 if x=="Cancellation Confirmation" else 0, IntegerType())
user_log = user_log\
.withColumn("upgrade", flag_Cancellation_event("page"))\
.withColumn("Churn", Fsum("upgrade").over(windowval))
user_log.head()
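To see what the windowed cumulative sum does, the same labeling can be mimicked in pandas on a toy log (hypothetical data): summing the cancellation flag over a per-user window ordered by descending timestamp marks every event of a churned user, not just the final one.

```python
import pandas as pd

# Toy log: user "1" cancels at ts=30, user "2" never cancels.
log = pd.DataFrame({
    "userId": ["1", "1", "1", "2", "2"],
    "ts":     [10, 20, 30, 10, 20],
    "page":   ["NextSong", "NextSong", "Cancellation Confirmation", "NextSong", "Home"],
})
# Flag the cancellation event, then cumulatively sum it per user in
# descending time order, as the Spark Window above does.
log["upgrade"] = (log["page"] == "Cancellation Confirmation").astype(int)
log = log.sort_values("ts", ascending=False)
log["Churn"] = log.groupby("userId")["upgrade"].cumsum()
print(log.sort_values(["userId", "ts"]))
```

Every row for user "1" ends up with Churn=1, every row for user "2" with Churn=0.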
# add day/month columns: convert the millisecond timestamp to a date string
get_day = udf(lambda x: datetime.datetime.fromtimestamp(x/1000.0).strftime('%Y%m%d'))
get_month = udf(lambda x: datetime.datetime.fromtimestamp(x/1000.0).strftime('%Y%m'))
user_log = user_log.withColumn("day", get_day('ts'))
user_log = user_log.withColumn("month", get_month('ts'))
user_log = user_log.withColumn("week", weekofyear(from_unixtime(user_log.ts / 1000.0)))
user_log.show(2)
The log covers only two months:
user_log.agg({'day':'max'}).show()
user_log.agg({'day':'min'}).show()
The following shows that the dataset is imbalanced:
user_Churn=user_log\
.select(['userId','Churn'])\
.dropDuplicates(['userId'])
plot_hist(df=user_Churn,param='Churn',title="Churn ",x_axis="Churn",y_axis="number of user ")
Select potential explanatory variables. Some variables are transformed so that they are likely to be informative about the target variable.
user_log.filter(user_log.userId=='143').orderBy('ts').show(3)
gender=user_log\
.select(['userId','gender','Churn'])\
.dropDuplicates(['userId'])\
.replace(['F','M'],['0','1'],'gender')
gender = gender.withColumn('gender', gender.gender.cast('int'))
gender.show(10)
plot_hist(df=gender,param='gender',title="gender ",x_axis="gender",y_axis="number of user ")
level=user_log\
.select(['userId','level','Churn'])\
.dropDuplicates(['userId'])\
.replace(['paid','free'],['0','1'],'level')
level= level.withColumn('level', level.level.cast('int'))
level.show(10)
plot_hist(df=level,param='level',title="level ",x_axis="level",y_axis="number of user ")
#user_log.groupby("userId","sessionId").count().orderBy("userId",user_log.day.cast("float")).groupby("userId").avg().orderBy("userId")
session_song=user_log\
.filter(user_log.page == 'NextSong')\
.select('page','userId','sessionId')\
.groupby("userId","sessionId")\
.agg({'page':'count'})\
.withColumnRenamed('count(page)', 'count_song') \
.sort(('userId'))
session_song.show(10)
session_song=session_song.groupby("userId")\
.agg({'count_song':'avg'})\
.withColumnRenamed('avg(count_song)', 'songs_per_session')
session_song=session_song\
.join(user_Churn,on='userId',how="inner")
session_song.show(10)
plot_hist(df=session_song,param='songs_per_session',title="songs per session ",y_axis="number of user",x_axis="song count per session")
songs_in_day = user_log\
.groupby("userId","day")\
.count()\
.orderBy("userId",user_log.day.cast("float"))
songs_in_day.show(n=10)
songs_in_day=songs_in_day.groupby("userId")\
.agg({'count':'avg'})\
.withColumnRenamed('avg(count)', 'songs_per_day') \
.orderBy('userId')
#songs_in_day.show(n=10)
songs_in_day=songs_in_day\
.join(user_Churn,on='userId',how="inner")
songs_in_day.show(10)
plot_hist(df=songs_in_day,param='songs_per_day',title="songs per day ",y_axis="number of user",x_axis="songs per day")
std_songs_in_day = user_log\
.groupby("userId","day")\
.count()\
.orderBy("userId",user_log.day.cast("float"))
std_songs_in_day.show(n=10)
std_songs_in_day=std_songs_in_day.groupby("userId")\
.agg({'count':'stddev_pop'})\
.withColumnRenamed('stddev_pop(count)', 'stddev_songs_per_day') \
.orderBy('userId')
#songs_in_day.show(n=10)
std_songs_in_day=std_songs_in_day\
.join(user_Churn,on='userId',how="inner")
std_songs_in_day.show(10)
plot_hist(df=std_songs_in_day,param='stddev_songs_per_day',title="stddev_songs per day ",y_axis="number of user",x_axis="stddev_songs per day")
songs_in_week = user_log\
.groupby("userId","week")\
.count()\
.orderBy("userId",user_log.week.cast("float"))
songs_in_week.show(n=10)
songs_in_week=songs_in_week.groupby("userId")\
.agg({'count':'avg'})\
.withColumnRenamed('avg(count)', 'songs_per_week') \
.orderBy('userId')
#songs_in_day.show(n=10)
songs_in_week=songs_in_week\
.join(user_Churn,on='userId',how="inner")
songs_in_week.show(10)
plot_hist(df=songs_in_week,param='songs_per_week',title="songs per week ",y_axis="number of user",x_axis="songs per week")
std_songs_in_week = user_log\
.groupby("userId","week")\
.count()\
.orderBy("userId",user_log.week.cast("float"))
std_songs_in_week.show(n=10)
std_songs_in_week=std_songs_in_week.groupby("userId")\
.agg({'count':'stddev_pop'})\
.withColumnRenamed('stddev_pop(count)', 'stddev_songs_per_week') \
.orderBy('userId')
#songs_in_day.show(n=10)
std_songs_in_week=std_songs_in_week\
.join(user_Churn,on='userId',how="inner")
std_songs_in_week.show(10)
plot_hist(df=std_songs_in_week,param='stddev_songs_per_week',title="stddev_songs per week ",y_axis="number of user",x_axis="stddev_songs per week")
songs_in_month = user_log\
.groupby("userId","month")\
.count()\
.orderBy("userId",user_log.month.cast("float"))
songs_in_month.show(n=10)
songs_in_month=songs_in_month.groupby("userId")\
.agg({'count':'avg'})\
.withColumnRenamed('avg(count)', 'songs_per_month') \
.orderBy('userId')
#songs_in_day.show(n=10)
songs_in_month=songs_in_month\
.join(user_Churn,on='userId',how="inner")
songs_in_month.show(10)
plot_hist(df=songs_in_month,param='songs_per_month',title="songs per month ",y_axis="number of user",x_axis="songs per month")
std_songs_in_month = user_log\
.groupby("userId","month")\
.count()\
.orderBy("userId",user_log.month.cast("float"))
std_songs_in_month.show(n=10)
std_songs_in_month=std_songs_in_month.groupby("userId")\
.agg({'count':'stddev_pop'})\
.withColumnRenamed('stddev_pop(count)', 'stddev_songs_per_month') \
.orderBy('userId')
#songs_in_day.show(n=10)
std_songs_in_month=std_songs_in_month\
.join(user_Churn,on='userId',how="inner")
std_songs_in_month.show(10)
plot_hist(df=std_songs_in_month,param='stddev_songs_per_month',title="stddev_songs per month ",y_axis="number of user",x_axis="stddev_songs per month")
songs=user_log\
.groupby("userId")\
.agg({'song':'count'})\
.withColumnRenamed('count(song)', 'songs heard so far') \
.orderBy("userId")
songs=songs\
.join(user_Churn,on='userId',how="inner")
songs.show(10)
plot_hist(df=songs,param='songs heard so far',title="Number of songs heard so far ",y_axis="number of user",x_axis="songs ")
Thumbs_Down=user_log\
.filter(user_log.page == 'Thumbs Down')\
.select('page','userId')\
.groupby("userId")\
.agg({'page':'count'})\
.withColumnRenamed('count(page)', 'Thumbs Down') \
.orderBy("userId")
Thumbs_Down=Thumbs_Down\
.join(user_Churn,on='userId',how="inner")
Thumbs_Down.show(n=10)
plot_hist(df=Thumbs_Down,param='Thumbs Down',title="Number of Thumbs Down ",y_axis="number of user",x_axis="Thumbs Down")
Thumbs_Up=user_log\
.filter(user_log.page == 'Thumbs Up')\
.select('page','userId')\
.groupby("userId")\
.agg({'page':'count'})\
.withColumnRenamed('count(page)', 'Thumbs Up') \
.orderBy("userId")
Thumbs_Up=Thumbs_Up\
.join(user_Churn,on='userId',how="inner")
Thumbs_Up.show(10)
plot_hist(df=Thumbs_Up,param='Thumbs Up',title="Number of Thumbs Up ",y_axis="number of user",x_axis="Thumbs Up")
playlist=user_log\
.filter(user_log.page == 'Add to Playlist')\
.select('page','userId')\
.groupby("userId")\
.agg({'page':'count'})\
.withColumnRenamed('count(page)', 'playlist') \
.orderBy("userId")
playlist=playlist\
.join(user_Churn,on='userId',how="inner")
playlist.show(10)
plot_hist(df=playlist,param='playlist',title="Number of playlist ",y_axis="number of user",x_axis="Playlist")
Friend=user_log\
.filter(user_log.page == 'Add Friend')\
.select('page','userId')\
.groupby("userId")\
.agg({'page':'count'})\
.withColumnRenamed('count(page)', 'Friend') \
.orderBy("userId")
Friend=Friend\
.join(user_Churn,on='userId',how="inner")
Friend.show(10)
plot_hist(df=Friend,param='Friend',title="Number of Friend ",y_axis="number of user",x_axis="Friend")
length=user_log\
.select('length','userId')\
.groupby("userId")\
.agg({'length':'sum'})\
.withColumnRenamed('sum(length)', 'total_length') \
.orderBy("userId")
length=length\
.join(user_Churn,on='userId',how="inner")
length.show(10)
plot_hist(df=length,param='total_length',title="total_length",y_axis="number of user",x_axis="total_length")
df_s=user_log\
.select('ts','userId','registration')\
.groupby("userId")\
.agg({'ts':'max','registration':'max'})\
.withColumnRenamed('max(ts)', 'ts') \
.withColumnRenamed('max(registration)', 'regi') \
.orderBy("userId")
#df_s.show(2)
regi_length=df_s.select(col('userId'),
    datediff(from_unixtime(col('ts')/1000),from_unixtime(col('regi')/1000)).alias('Day from registration'))
regi_length=regi_length\
.join(user_Churn,on='userId',how="inner")
regi_length.show(10)
plot_hist(df=regi_length,param='Day from registration',title="Days from registration",y_axis="number of user",x_axis="Days from registration")
Once you've familiarized yourself with the data, build out the features you find promising to train your model on. To work with the full dataset, you can follow these steps.
If you are working in the classroom workspace, you can just extract features based on the small subset of data contained here. Be sure to transfer over this work to the larger dataset when you work on your Spark cluster.
The original time-series log is aggregated into per-user features. Concatenate the features extracted above and rename the target variable "Churn" to "target".
result=session_song.select('userId','Churn','songs_per_session')\
.join(gender.drop('Churn'),on='userId',how="inner")\
.join(level.drop('Churn'),on='userId',how="inner")\
.join(songs_in_day.drop('Churn'),on='userId',how="inner")\
.join(std_songs_in_day.drop('Churn'),on='userId',how="inner")\
.join(songs_in_week.drop('Churn'),on='userId',how="inner")\
.join(std_songs_in_week.drop('Churn'),on='userId',how="inner")\
.join(songs_in_month.drop('Churn'),on='userId',how="inner")\
.join(std_songs_in_month.drop('Churn'),on='userId',how="inner")\
.join(songs.drop('Churn'),on='userId',how="inner")\
.join(Thumbs_Up.drop('Churn'),on='userId',how="inner")\
.join(Thumbs_Down.drop('Churn'),on='userId',how="inner")\
.join(playlist.drop('Churn'),on='userId',how="inner")\
.join(Friend.drop('Churn'),on='userId',how="inner")\
.join(length.drop('Churn'),on='userId',how="inner")\
.join(regi_length.drop('Churn'),on='userId',how="inner")\
.withColumnRenamed("Churn","target")
result.printSchema()
result.show(4)
Split the full dataset into train, test, and validation sets. Test out several of the machine learning methods you learned. Evaluate the accuracy of the various models, tuning parameters as necessary. Determine your winning model based on test accuracy and report results on the validation set. Since the churned users are a fairly small subset, I suggest using F1 score as the metric to optimize.
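For intuition on why F1 is the right metric for this imbalanced problem, precision, recall, F1, and accuracy can be computed from raw confusion counts. The counts below are made up to resemble an imbalanced churn test set:

```python
# Hypothetical confusion counts: only 30 of 225 users churned.
tp, fp, fn, tn = 20, 5, 10, 190

precision = tp / (tp + fp)
recall = tp / (tp + fn)
f1 = 2 * precision * recall / (precision + recall)
accuracy = (tp + tn) / (tp + fp + fn + tn)

# Accuracy looks strong (~0.93) even though a third of churners are missed;
# F1 (~0.73) reflects minority-class performance more honestly.
print(f"precision={precision:.3f} recall={recall:.3f} f1={f1:.3f} accuracy={accuracy:.3f}")
```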
result=result.drop('userId')
col=result.columns   # NOTE: this shadows pyspark.sql.functions.col; a name like feature_cols would be safer
col.remove('target')
col
train, test = result.randomSplit([0.7, 0.3], seed=0)
cols=result.columns
assembler = VectorAssembler(inputCols=cols, outputCol="features")
output = assembler.transform(result)
sdf2 = output.select("features")
from pyspark.ml.stat import Correlation
s_corr = Correlation.corr(sdf2, "features", "pearson").head()
#print("Pearson correlation matrix:\n" + str(s_corr[0]))
s_corr_ls = s_corr[0].toArray().tolist()
s_corr_df = spark.createDataFrame(s_corr_ls, cols)
p_corr_df = s_corr_df.select("*").toPandas()
r_index = pd.Series(cols)
p_corr_df = p_corr_df.set_index(r_index)
#print(p_corr_df)
import seaborn as sns
cm = sns.clustermap(p_corr_df, annot=True, figsize=(8, 8), col_cluster=False, row_cluster=False, fmt="1.1f")
cm.cax.set_visible(False)
display()
def build_model(classifier, param):
    """
    Pipeline of assembler, scaler and CrossValidator.
    input:
        classifier : estimator for CrossValidator
        param      : parameter grid for the estimator
    output:
        pipeline model
    """
    assembler = VectorAssembler(inputCols=col, outputCol="features")
    scaler = StandardScaler(inputCol="features", outputCol="scaled_features")
    cv = CrossValidator(
        estimator=classifier,
        estimatorParamMaps=param,
        evaluator=MulticlassClassificationEvaluator(labelCol='target', metricName='f1'),
        numFolds=5,
    )
    model = Pipeline(stages=[assembler, scaler, cv])
    return model
Since the dataset is imbalanced, the F1 score is the most suitable evaluation metric.
def metrics(pred):
    """
    Print F1 score, accuracy, precision and recall.
    input
        pred: predictions after fit and transform
    """
    evaluator = MulticlassClassificationEvaluator(labelCol="target", predictionCol="prediction")
    precision = evaluator.evaluate(pred, {evaluator.metricName: "weightedPrecision"})
    recall = evaluator.evaluate(pred, {evaluator.metricName: "weightedRecall"})
    f1 = evaluator.evaluate(pred, {evaluator.metricName: "f1"})
    accuracy = evaluator.evaluate(pred, {evaluator.metricName: "accuracy"})
    print("F1       : {}".format(f1))
    print("Accuracy : {}".format(accuracy))
    print("Precision: {}".format(precision))
    print("Recall   : {}".format(recall))
lr = LogisticRegression(featuresCol="scaled_features", labelCol="target")
param = ParamGridBuilder().build()
model = build_model(lr, param)
fit_model = model.fit(train)
pred = fit_model.transform(test)
#pred.select("prediction").dropDuplicates().collect()
best_model_lr=fit_model.stages[-1].bestModel
feature_coef = best_model_lr.coefficients
feature_coef_df = pd.DataFrame(list(zip(col, feature_coef)), columns=['Feature', 'Coefficient']).sort_values('Coefficient', ascending=False)
feature_coef_df.plot(kind='barh',x='Feature',y='Coefficient',legend=False)
plt.title('Feature importance for Logistic Regression')
plt.xlabel('Coefficient')
metrics(pred)
gbt = GBTClassifier(labelCol="target", featuresCol="scaled_features")
param = ParamGridBuilder().addGrid(gbt.maxDepth, [2,6]).addGrid(gbt.maxBins, [10,20]).addGrid(gbt.maxIter, [5,10]).build()
print(param)
model = build_model(gbt, param)
gbt_model = model.fit(train)
pred_gbt = gbt_model.transform(test)
best_model_gbt=gbt_model.stages[-1].bestModel
featureImportances=best_model_gbt.featureImportances
feature_Importance_df = pd.DataFrame(list(zip(col, featureImportances)), columns=['Feature', 'featureImportances']).sort_values('featureImportances', ascending=True)
feature_Importance_df.plot(kind='barh',x='Feature',y='featureImportances',legend=False)
plt.title('Feature importance for GBTClassifier')
plt.xlabel('Feature importance')
best_model_gbt._java_obj.getMaxDepth()
best_model_gbt._java_obj.getMaxBins()
best_model_gbt._java_obj.getMaxIter()
metrics(pred_gbt)
rf = RandomForestClassifier(labelCol="target", featuresCol="scaled_features")
param = ParamGridBuilder().addGrid(rf.maxDepth, [2,6]).addGrid(rf.maxBins, [10,20]).addGrid(rf.numTrees, [5, 20, 50]).build()
model = build_model(rf, param)
rf_model = model.fit(train)
pred_rf = rf_model.transform(test)
best_model_rf=rf_model.stages[-1].bestModel
featureImportances=best_model_rf.featureImportances
feature_Importance_df = pd.DataFrame(list(zip(col, featureImportances)), columns=['Feature', 'featureImportances']).sort_values('featureImportances', ascending=True)
feature_Importance_df.plot(kind='barh',x='Feature',y='featureImportances',legend=False)
plt.title('Feature importance for Random Forest')
plt.xlabel('Feature importance')
best_model_rf._java_obj.getMaxDepth()
best_model_rf._java_obj.getNumTrees()
best_model_rf._java_obj.getMaxBins()
best_model_rf.featureImportances
metrics(pred_rf)
Model selection
Try the following three methods supported by PySpark ML.
1. Logistic Regression
When the target variable and the explanatory variables have a linear relationship, this is an excellent method in terms of computational cost and model interpretability, so we check it as a baseline model. On the other hand, note that prediction accuracy can deteriorate due to multicollinearity when the explanatory variables are highly correlated.
2. Random Forest / 3. GBT
Assuming a nonlinear relationship between the target and the explanatory variables, we select tree-based methods (Random Forest and gradient-boosted trees), which retain some interpretability. These methods can also be expected to be sparse, so the classifier can surface the most informative variables.
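The multicollinearity caveat for the linear model can be demonstrated with NumPy on toy data (not Sparkify data): when two features are perfectly correlated, several different coefficient vectors fit equally well, so individual coefficients stop being interpretable.

```python
import numpy as np

# Two perfectly correlated features: column 2 duplicates column 1.
X = np.array([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])
y = np.array([2.0, 4.0, 6.0])  # y = 2 * x1

# Two very different weight vectors reproduce y exactly.
w_a = np.array([2.0, 0.0])
w_b = np.array([0.0, 2.0])
assert np.allclose(X @ w_a, y) and np.allclose(X @ w_b, y)

# Least squares falls back to the minimum-norm solution [1, 1],
# silently splitting the effect between the duplicated features.
w_min, *_ = np.linalg.lstsq(X, y, rcond=None)
print(w_min)
```

In practice this is why the correlation heatmap above is worth inspecting before trusting logistic-regression coefficients.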
Model improvement
1. Model tuning
We performed a grid search over the maxDepth, maxBins, and maxIter/numTrees parameters of the tree-based models, and adopted the parameters with the best score.
2. Robustness
After splitting the data into train and test sets at a 7:3 ratio, the training data was cross-validated with 5 folds and the average score was adopted. This is an effective means of improving robustness when the amount of data is relatively small.
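The 5-fold procedure can be sketched in plain Python (the per-fold score below is a dummy stand-in; CrossValidator performs the equivalent splitting and averaging on Spark DataFrames):

```python
def kfold_splits(n_rows, k):
    """Yield (train_idx, val_idx) pairs partitioning range(n_rows) into k folds."""
    indices = list(range(n_rows))
    fold_size, remainder = divmod(n_rows, k)
    start = 0
    for i in range(k):
        stop = start + fold_size + (1 if i < remainder else 0)
        yield indices[:start] + indices[stop:], indices[start:stop]
        start = stop

# Average a (dummy) score over 5 folds, as CrossValidator does internally.
scores = [len(val) / 10 for _, val in kfold_splits(10, 5)]
mean_score = sum(scores) / len(scores)
print(mean_score)
```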
Results
We trained three models and obtained the results below. When running a retention campaign targeted at churning users, we want to reduce missed churners (false negatives), while over-detection (false positives) wastes campaign budget. The campaign can be run efficiently by optimizing the F1 score, which is the harmonic mean of recall and precision.
| model | f1 score | accuracy | recall | Precision |
|---|---|---|---|---|
| Logistic Regression | 0.85 | 0.85 | 0.84 | 0.85 |
| GBTClassifier | 0.87 | 0.89 | 0.90 | 0.89 |
| Random forest | 0.75 | 0.81 | 0.76 | 0.81 |
Future considerations
Check the Feature Importance of the GBT model. You can see that the following variables contributed to the classification accuracy.
Conclusion
A classifier on a distributed platform was created by capturing the characteristics of users who tend to cancel. The relationship proved difficult for a linear model to predict, and a tree-based classifier turned out to be more suitable.
We selected important features, but in order to understand actual market trends it is necessary to use larger-scale data, and we can expect that the prediction accuracy and the important features may change. A large-scale distributed infrastructure is indispensable for developing services based on vast amounts of customer data.
Clean up your code, adding comments and renaming variables to make the code easier to read and maintain. Refer to the Spark Project Overview page and Data Scientist Capstone Project Rubric to make sure you are including all components of the capstone project and meet all expectations. Remember, this includes thorough documentation in a README file in a Github repository, as well as a web app or blog post.